package cn.caplike.demo.spring.security.dynamic.authorization.configuration.security;

import cn.caplike.data.redis.service.spring.boot.starter.RedisKey;
import cn.caplike.data.redis.service.spring.boot.starter.RedisService;
import cn.caplike.demo.spring.security.dynamic.authorization.configuration.security.exception.UserInfoIncompleteException;
import cn.caplike.demo.spring.security.dynamic.authorization.domain.entity.User;
import cn.caplike.demo.spring.security.dynamic.authorization.util.AccessTokenUtils;
import cn.caplike.demo.spring.security.dynamic.authorization.util.CsrfTokenUtils;
import cn.caplike.demo.spring.security.dynamic.authorization.util.RequestUtils;
import com.alibaba.fastjson.JSON;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Redis-backed {@link CsrfTokenRepository}: each user's csrf-token is cached in Redis under a key built from the
 * user's name (prefix) and {@link SecurityConfiguration#CSRF_TOKEN} (suffix).
 * <p>
 * This is a {@code @Component}, i.e. a singleton shared by all concurrent requests, so it keeps no per-request state
 * in fields. The user name resolved for a request is stored as a request attribute instead.
 *
 * @author LiKe
 * @version 1.0.0
 * @date 2020-05-12 16:53
 */
@Slf4j
@Component
public class CsrfTokenRedisRepository implements CsrfTokenRepository {

    /**
     * Request parameter name of the csrf-token.
     */
    private static final String CSRF_PARAMETER_NAME = "_csrf";

    /**
     * Name of the request header that carries the csrf-token.
     */
    private static final String CSRF_HEADER_NAME = "X-CSRF-TOKEN";

    /**
     * Suffix of the Redis key of the cached token; also used as the response header name in
     * {@link #saveToken(CsrfToken, HttpServletRequest, HttpServletResponse)}.
     */
    private static final String CSRF_TOKEN = SecurityConfiguration.CSRF_TOKEN;

    /**
     * URIs for which {@link #loadToken(HttpServletRequest)} hands out a throw-away token so that
     * {@link #saveToken(CsrfToken, HttpServletRequest, HttpServletResponse)} is not triggered (the /auth/register
     * endpoint).
     */
    private static final Set<String> IGNORING_SAVING_TOKEN_LIST = Stream.of(SecurityConfiguration.REGISTER_URI).collect(Collectors.toSet());

    /**
     * Request attribute under which the user name resolved for the current request is kept, so that every call
     * within one request uses the same name without sharing state across requests.
     */
    private static final String USER_NAME_ATTRIBUTE = CsrfTokenRedisRepository.class.getName() + ".USER_NAME";

    /**
     * Message of the exception thrown when no user name can be resolved from the request.
     */
    private static final String USER_NAME_UNRESOLVABLE_MESSAGE = "Cannot retrieve user's name from request (neither inputStream nor header Authorization)";

    /**
     * {@link RedisService}, injected via {@link #setRedisService(RedisService)}.
     */
    private RedisService redisService;

    /**
     * Caches {@code token} for the current user, or deletes the cached token when {@code token} is {@code null}.
     * Called whenever generateToken has produced a new token.
     * <p>
     * Except on the login endpoint, the token is also returned to the client in the {@link #CSRF_TOKEN} response
     * header.
     *
     * @param token    the token to store, or {@code null} to remove the cached one
     * @param request  the current request, used to resolve the user name and the request URI
     * @param response the current response, receiving the token header
     */
    @SneakyThrows
    @Override
    public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
        log.debug("csrf filter: redis csrf token repository: save token");

        final String name = resolveUserName(request);

        if (Objects.isNull(token)) {
            redisService.delete(RedisKey.builder().prefix(name).suffix(CSRF_TOKEN).build());
            return;
        }

        // cache the csrf-token
        final String csrfToken = token.getToken();
        redisService.setValue(RedisKey.builder().prefix(name).suffix(CSRF_TOKEN).build(), csrfToken);

        if (StringUtils.equals(RequestUtils.getQualifiedURI(request), SecurityConfiguration.LOGIN_URI)) {
            // For the login endpoint the header is set after successful login (JWTAuthenticationFilter#successfulAuthentication)
            return;
        }

        // expose the csrf-token in the response header
        response.setHeader(CSRF_TOKEN, csrfToken);
    }

    /**
     * Loads the cached csrf-token of the current user. This is the first repository method called for a request.
     *
     * @param request the current request
     * @return a throw-away token for URIs in {@link #IGNORING_SAVING_TOKEN_LIST}; otherwise the cached token, or
     * {@code null} when none is cached (which makes the caller generate and save a new one)
     */
    @SneakyThrows
    @Override
    public CsrfToken loadToken(HttpServletRequest request) {
        log.debug("csrf filter: redis csrf token repository: load token");

        final String name = resolveUserName(request);

        // Registration endpoint: return a temporary token; a non-null result avoids triggering saveToken
        if (IGNORING_SAVING_TOKEN_LIST.contains(RequestUtils.getQualifiedURI(request))) {
            return generateToken(request);
        }

        final String csrfToken = getCachedToken(name);
        return StringUtils.isBlank(csrfToken) ? null : new DefaultCsrfToken(CSRF_HEADER_NAME, CSRF_PARAMETER_NAME, csrfToken);
    }

    /**
     * Creates a new token. Called when loadToken returned {@code null}; saveToken is called afterwards.
     */
    @Override
    public CsrfToken generateToken(HttpServletRequest request) {
        return new DefaultCsrfToken(CSRF_HEADER_NAME, CSRF_PARAMETER_NAME, CsrfTokenUtils.create());
    }

    /**
     * Resolves the name of the user the request belongs to, reusing the value already resolved earlier in the same
     * request if there is one.
     * <p>
     * The name is taken from the JSON request body (deserialized as {@link User}). If the body yields no name, it is
     * taken from the subject of the access token in the Authorization header.
     *
     * @param request the current request
     * @return the non-blank user name
     * @throws UserInfoIncompleteException if neither the body nor the Authorization header yields a name
     */
    private String resolveUserName(HttpServletRequest request) throws UserInfoIncompleteException {
        final Object resolved = request.getAttribute(USER_NAME_ATTRIBUTE);
        if (resolved instanceof String) {
            return (String) resolved;
        }

        String name = readNameFromBody(request);
        if (StringUtils.isBlank(name)) {
            name = readNameFromAuthorizationHeader(request);
        }
        if (StringUtils.isBlank(name)) {
            throw new UserInfoIncompleteException(USER_NAME_UNRESOLVABLE_MESSAGE);
        }

        request.setAttribute(USER_NAME_ATTRIBUTE, name);
        return name;
    }

    /**
     * Reads the user name from the JSON request body.
     * <p>
     * NOTE(review): this consumes the request input stream; downstream readers presumably rely on a re-readable
     * (cached) request wrapper — confirm one is installed before this filter.
     *
     * @param request the current request
     * @return the body's {@link User#getName()}, or {@code null} when the body is empty or cannot be parsed
     */
    private String readNameFromBody(HttpServletRequest request) {
        try {
            final User user = JSON.parseObject(request.getInputStream(), StandardCharsets.UTF_8, User.class);
            return Objects.isNull(user) ? null : user.getName();
        } catch (IOException | RuntimeException e) {
            // fastjson reports malformed input with unchecked exceptions (e.g. JSONException);
            // a body that is not a user is not fatal — the Authorization header is tried next
            log.debug("csrf filter: request body does not contain a readable user", e);
            return null;
        }
    }

    /**
     * Reads the user name from the subject of the access token in the Authorization header.
     *
     * @param request the current request
     * @return the token subject, or {@code null} when the header is absent or blank
     * @throws UserInfoIncompleteException if the header is present but its token cannot be parsed
     */
    private String readNameFromAuthorizationHeader(HttpServletRequest request) throws UserInfoIncompleteException {
        final String authorization = request.getHeader(AccessTokenUtils.AUTHORIZATION_HEADER);
        if (StringUtils.isBlank(authorization)) {
            return null;
        }
        try {
            return AccessTokenUtils.getSubject(authorization.replaceFirst(AccessTokenUtils.BEARER_TOKEN_TYPE, StringUtils.EMPTY));
        } catch (Exception e) {
            // keep the root cause visible; the thrown exception only carries the generic message
            log.debug("csrf filter: cannot read subject from Authorization header", e);
            throw new UserInfoIncompleteException(USER_NAME_UNRESOLVABLE_MESSAGE);
        }
    }

    /**
     * Reads the cached csrf-token of the given user from Redis.
     *
     * @param name the user name (Redis key prefix)
     * @return the cached token; may be {@code null} or blank when nothing is cached
     */
    private String getCachedToken(String name) {
        return redisService.getValue(RedisKey.builder().prefix(name).suffix(CSRF_TOKEN).build(), String.class);
    }

    // ~ Autowired
    // -----------------------------------------------------------------------------------------------------------------

    @Autowired
    public void setRedisService(RedisService redisService) {
        this.redisService = redisService;
    }
}
